package com.xunmall.gateway.interceptor.security;

import com.xunmall.base.util.StringUtils;
import com.xunmall.base.util.encode.EncodeUtils;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils;
import org.assertj.core.util.Lists;

import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.*;
import java.net.URLDecoder;
import java.net.URLEncoder;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.regex.Pattern;

/**
 * Request wrapper that passes incoming request data through {@link EncodeUtils#xssEncode}
 * before downstream code reads it.
 *
 * <p>On construction the wrapper copies the request parameters, merges in caller-supplied
 * parameters, buffers the whole request body in memory and encodes:
 * <ul>
 *   <li>every non-empty parameter value (non-multipart requests),</li>
 *   <li>the complete body (non-multipart requests), or</li>
 *   <li>the content of every multipart part whose headers carry no {@code filename=}
 *       (multipart requests; file parts are passed through untouched).</li>
 * </ul>
 * Query-string values are encoded lazily on the first {@link #getQueryString()} call.
 *
 * <p>Encoding is best-effort: a failure is logged and the request continues with whatever
 * had been buffered up to that point.
 *
 * Created by Gimgoog on 2018/1/29.
 */
@Slf4j
public class XssEncoderHttpRequest extends HttpServletRequestWrapper {

    /** Single-byte charset used to carry raw bytes through a String without loss. */
    private static final Charset BINARY = StandardCharsets.ISO_8859_1;
    /** Used when the request declares no charset, or one this JVM does not support. */
    private static final Charset DEFAULT_CHARSET = StandardCharsets.UTF_8;
    private static final String QUERY_ENCODING = "UTF-8";
    private static final String HEADER_SEPARATOR = "\r\n\r\n";
    private static final String BOUNDARY_KEY = "boundary=";

    private final Map<String, String[]> allParameters = new TreeMap<>();
    private final Charset charset;
    /** Buffered (and possibly encoded) body; {@code null} if the body could not be read. */
    private byte[] body;
    /** Cached result of {@link #getQueryString()}; {@code null} until first computed. */
    private String queryString;

    /**
     * Creates a wrapper that merges additional parameters into the request and XSS-encodes
     * its parameters and body.
     *
     * <p>Note that this constructor reads the parameter map and the input stream of the
     * wrapped request, so the original request's body is consumed here.
     *
     * @param request          the request to wrap
     * @param additionalParams extra parameters to expose; may be {@code null}. Entries
     *                         override request parameters of the same name unless the
     *                         request supplies a non-empty value for that name.
     */
    public XssEncoderHttpRequest(final HttpServletRequest request, final Map<String, String[]> additionalParams) {
        super(request);

        allParameters.putAll(super.getParameterMap());
        if (additionalParams != null) {
            allParameters.putAll(additionalParams);
        }

        // getCharacterEncoding() may be null or name an unsupported charset; resolving it
        // up front keeps a bad Content-Type from aborting (and thereby skipping) the encoding.
        charset = resolveCharset(request.getCharacterEncoding());
        try {
            xssEncodeOper(request);
        } catch (Exception e) {
            log.error("xssEncode出错!", e);
        }
    }

    /**
     * Returns the first value of the named parameter from the merged parameter map, falling
     * back to the wrapped request when the map has no usable entry.
     */
    @Override
    public String getParameter(final String name) {
        String[] values = allParameters.get(name);
        if (values != null && values.length > 0) {
            return values[0];
        }
        return super.getParameter(name);
    }

    @Override
    public Map<String, String[]> getParameterMap() {
        // Return an unmodifiable view because the servlet contract requires an immutable map.
        return Collections.unmodifiableMap(allParameters);
    }

    @Override
    public Enumeration<String> getParameterNames() {
        return Collections.enumeration(allParameters.keySet());
    }

    @Override
    public String[] getParameterValues(final String name) {
        return allParameters.get(name);
    }

    /**
     * Returns a fresh stream over the buffered body, or the wrapped request's stream when
     * no body was buffered.
     */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        if (body == null) {
            return super.getInputStream();
        }
        return new ServletInputStreamImpl(new ByteArrayInputStream(body));
    }

    @Override
    public BufferedReader getReader() throws IOException {
        return new BufferedReader(new InputStreamReader(getInputStream(), charset));
    }

    /**
     * Returns the length of the buffered body, or the wrapped request's content length when
     * no body was buffered.
     */
    @Override
    public int getContentLength() {
        return body != null ? body.length : super.getContentLength();
    }

    @Override
    public long getContentLengthLong() {
        return body != null ? body.length : super.getContentLengthLong();
    }

    /**
     * Returns the query string with every value URL-decoded, XSS-encoded and URL-encoded
     * again (UTF-8). Parameter names are left untouched. Tokens without a name or without
     * {@code '='} are dropped. The result is computed once and cached.
     *
     * @return the encoded query string, or {@code null} if the request has none
     */
    @Override
    public String getQueryString() {
        if (this.queryString != null) {
            return this.queryString;
        }
        String raw = super.getQueryString();
        if (raw == null) {
            return null;
        }

        List<String> pairs = new ArrayList<>();
        StringTokenizer tokens = new StringTokenizer(raw, "&");
        while (tokens.hasMoreTokens()) {
            String token = tokens.nextToken();
            int eq = token.indexOf('=');
            if (eq <= 0) {
                continue;
            }
            String name = token.substring(0, eq);
            pairs.add(name + "=" + encodeQueryValue(name, token.substring(eq + 1)));
        }
        this.queryString = String.join("&", pairs);
        return this.queryString;
    }

    // =====================================================================================

    /**
     * Adds or replaces a single-valued parameter. The value is stored as given; it is not
     * XSS-encoded by this method.
     */
    public void addParameter(String name, String value) {
        allParameters.put(name, new String[]{value});
    }

    /** Picks the request charset, falling back to UTF-8 when it is absent or unsupported. */
    private static Charset resolveCharset(String requested) {
        if (requested != null) {
            try {
                if (Charset.isSupported(requested)) {
                    return Charset.forName(requested);
                }
            } catch (IllegalArgumentException ignored) {
                // Syntactically illegal charset name sent by the client: use the default.
            }
        }
        return DEFAULT_CHARSET;
    }

    /** Decodes, XSS-encodes and re-encodes one query value; yields "" when that fails. */
    private String encodeQueryValue(String name, String rawValue) {
        try {
            String decoded = URLDecoder.decode(rawValue, QUERY_ENCODING);
            return URLEncoder.encode(EncodeUtils.xssEncode(decoded), QUERY_ENCODING);
        } catch (Exception e) {
            // Best-effort, e.g. malformed %-escapes: emit an empty value rather than raw input.
            log.warn("Could not XSS-encode query parameter '{}'; value replaced by empty string", name, e);
            return "";
        }
    }

    /**
     * Escapes XSS-dangerous characters in the incoming parameters and body.
     *
     * @param req the original (unwrapped) request
     * @throws IOException if the request body cannot be read
     */
    private void xssEncodeOper(final HttpServletRequest req) throws IOException {
        String contentType = req.getContentType();
        if (contentType == null) {
            contentType = "";
        }
        if (contentType.startsWith("multipart")) {
            // File-upload form. The raw bytes are stored first so the body stays readable
            // even if the encoding step below throws.
            body = IOUtils.toByteArray(super.getInputStream());
            body = encodeMultipartBody(body, contentType);
        } else {
            // Ordinary request: parameters first, then the complete body.
            encodeParameters(req);
            body = IOUtils.toByteArray(super.getInputStream());
            // NOTE(review): this is applied to every non-multipart body regardless of
            // content type, so binary payloads are decoded as text as well — confirm intended.
            body = EncodeUtils.xssEncode(new String(body, charset)).getBytes(charset);
        }
    }

    /** XSS-encodes every non-empty value of every request parameter into the merged map. */
    private void encodeParameters(HttpServletRequest req) {
        Enumeration<String> names = req.getParameterNames();
        while (names.hasMoreElements()) {
            String name = names.nextElement();
            String[] values = req.getParameterValues(name);
            if (values == null) {
                continue;
            }
            String[] encoded = new String[values.length];
            boolean hasContent = false;
            for (int i = 0; i < values.length; i++) {
                String value = values[i];
                if (StringUtils.isEmpty(value)) {
                    encoded[i] = value;
                } else {
                    encoded[i] = EncodeUtils.xssEncode(value);
                    hasContent = true;
                }
            }
            // An all-empty parameter does not displace an entry that is already in the map.
            if (hasContent) {
                allParameters.put(name, encoded);
            }
        }
    }

    /**
     * XSS-encodes the content of every non-file part of a multipart body.
     *
     * @return the re-assembled body, or {@code raw} unchanged if no boundary is declared
     */
    private byte[] encodeMultipartBody(byte[] raw, String contentType) {
        String boundary = extractBoundary(contentType);
        if (boundary == null) {
            return raw;
        }
        String delimiter = "--" + boundary;

        // ISO-8859-1 maps bytes to chars one-to-one, so file parts survive the round trip.
        String text = new String(raw, BINARY);
        // The boundary is literal text, not a regex; limit -1 keeps trailing empty pieces
        // so that joining on the delimiter reproduces untouched parts byte for byte.
        String[] parts = text.split(Pattern.quote(delimiter), -1);
        for (int i = 0; i < parts.length; i++) {
            String part = parts[i];
            int headerEnd = part.indexOf(HEADER_SEPARATOR);
            if (headerEnd < 0) {
                continue;
            }
            String headers = part.substring(0, headerEnd);
            // Only the part headers decide whether this is a file; a field value that
            // merely contains "filename=" must still be encoded.
            if (headers.contains("filename=")) {
                continue;
            }
            String content = new String(part.substring(headerEnd).getBytes(BINARY), charset);
            String encoded = new String(EncodeUtils.xssEncode(content).getBytes(charset), BINARY);
            parts[i] = headers + encoded;
        }
        return String.join(delimiter, parts).getBytes(BINARY);
    }

    /** Extracts the boundary parameter of a multipart Content-Type, or null if absent. */
    private static String extractBoundary(String contentType) {
        int keyIndex = contentType.toLowerCase(Locale.ROOT).indexOf(BOUNDARY_KEY);
        if (keyIndex < 0) {
            return null;
        }
        String boundary = contentType.substring(keyIndex + BOUNDARY_KEY.length());
        int paramEnd = boundary.indexOf(';');
        if (paramEnd >= 0) {
            boundary = boundary.substring(0, paramEnd);
        }
        boundary = boundary.trim();
        if (boundary.length() >= 2 && boundary.startsWith("\"") && boundary.endsWith("\"")) {
            boundary = boundary.substring(1, boundary.length() - 1);
        }
        return boundary.isEmpty() ? null : boundary;
    }

    /** {@link ServletInputStream} over an in-memory byte array; blocking reads only. */
    private static class ServletInputStreamImpl extends ServletInputStream {

        private final ByteArrayInputStream is;

        ServletInputStreamImpl(ByteArrayInputStream is) {
            this.is = is;
        }

        @Override
        public int read() throws IOException {
            return is.read();
        }

        @Override
        public int read(byte[] b, int off, int len) throws IOException {
            return is.read(b, off, len);
        }

        @Override
        public int available() throws IOException {
            return is.available();
        }

        @Override
        public boolean isFinished() {
            return is.available() == 0;
        }

        @Override
        public boolean isReady() {
            return true;
        }

        @Override
        public void setReadListener(ReadListener listener) {
            // Non-blocking reads are not supported for the buffered body.
            throw new UnsupportedOperationException("Not implemented");
        }
    }
}
